package com.onelogin.saml2.util; import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.StringReader; import java.io.UnsupportedEncodingException; import java.net.URL; import java.net.URLDecoder; import java.net.URLEncoder; import java.nio.charset.Charset; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.KeyFactory; import java.security.NoSuchAlgorithmException; import java.security.NoSuchProviderException; import java.security.Key; import java.security.PrivateKey; import java.security.Signature; import java.security.SignatureException; import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.security.spec.PKCS8EncodedKeySpec; import java.util.Calendar; import java.util.Iterator; import java.util.Locale; import java.util.TimeZone; import java.util.UUID; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; import java.util.zip.Inflater; import javax.crypto.KeyGenerator; import javax.crypto.SecretKey; import javax.xml.namespace.NamespaceContext; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; import javax.xml.parsers.SAXParser; import javax.xml.parsers.SAXParserFactory; import javax.xml.transform.Source; import javax.xml.transform.dom.DOMSource; import javax.xml.validation.Schema; import javax.xml.validation.Validator; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; import javax.xml.xpath.XPathExpressionException; import javax.xml.xpath.XPathFactory; import javax.xml.XMLConstants; import org.apache.commons.codec.binary.Base64; import 
org.apache.commons.codec.digest.DigestUtils; import org.apache.commons.lang3.StringUtils; import org.apache.xml.security.encryption.EncryptedData; import org.apache.xml.security.encryption.EncryptedKey; import org.apache.xml.security.encryption.XMLCipher; import org.apache.xml.security.exceptions.XMLSecurityException; import org.apache.xml.security.keys.KeyInfo; import org.apache.xml.security.signature.XMLSignature; import org.apache.xml.security.transforms.Transforms; import org.apache.xml.security.utils.XMLUtils; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.joda.time.Period; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISOPeriodFormat; import org.joda.time.format.PeriodFormatter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.w3c.dom.Attr; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.Node; import org.w3c.dom.NodeList; import org.xml.sax.InputSource; import org.xml.sax.SAXException; import com.onelogin.saml2.exception.ValidationError; import com.onelogin.saml2.exception.XMLEntityException; /** * Util class of OneLogin's Java Toolkit. * * A class that contains several auxiliary methods related to the SAML protocol */ public final class Util { /** * Private property to construct a logger for this class. 
*/ private static final Logger LOGGER = LoggerFactory.getLogger(Util.class); private static final DateTimeFormatter DATE_TIME_FORMAT = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss'Z'").withZone(DateTimeZone.UTC); private static final DateTimeFormatter DATE_TIME_FORMAT_MILLS = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'").withZone(DateTimeZone.UTC); public static final String UNIQUE_ID_PREFIX = "ONELOGIN_"; public static final String RESPONSE_SIGNATURE_XPATH = "/samlp:Response/ds:Signature"; public static final String ASSERTION_SIGNATURE_XPATH = "/samlp:Response/saml:Assertion/ds:Signature"; /** Indicates if JAXP 1.5 support has been detected. */ private static boolean JAXP_15_SUPPORTED = isJaxp15Supported(); private Util() { //not called } /** * Method which uses the recommended way ( https://docs.oracle.com/javase/tutorial/jaxp/properties/error.html ) * of checking if JAXP >= 1.5 options are supported. Needed if the project which uses this library also has * Xerces in it's classpath. * * If for whatever reason this method cannot determine if JAXP 1.5 properties are supported it will indicate the * options are supported. This way we don't accidentally disable configuration options. * * @return */ public static boolean isJaxp15Supported() { boolean supported = true; try { SAXParserFactory spf = SAXParserFactory.newInstance(); SAXParser parser = spf.newSAXParser(); parser.setProperty("http://javax.xml.XMLConstants/property/accessExternalDTD", "file"); } catch (SAXException ex) { String err = ex.getMessage(); if (err.contains("Property 'http://javax.xml.XMLConstants/property/accessExternalDTD' is not recognized.")) { //expected, jaxp 1.5 not supported supported = false; } } catch (Exception e) { LOGGER.info("An exception occurred while trying to determine if JAXP 1.5 options are supported.", e); } return supported; } /** * This function load an XML string in a save way. Prevent XEE/XXE Attacks * * @param xml * String. The XML string to be loaded. 
* * @return The result of load the XML at the Document or null if any error occurs */ public static Document loadXML(String xml) { try { if (xml.contains("<!ENTITY")) { throw new XMLEntityException("Detected use of ENTITY in XML, disabled to prevent XXE/XEE attacks"); } return convertStringToDocument(xml); } catch (XMLEntityException e) { LOGGER.debug("Load XML error due XMLEntityException.", e); } catch (Exception e) { LOGGER.debug("Load XML error: " + e.getMessage(), e); } return null; } /** * Extracts a node from the DOMDocument * * @param dom * The DOMDocument * @param query * Xpath Expression * @param context * Context Node (DomElement) * * @return DOMNodeList The queried node * * @throws XPathExpressionException */ public static NodeList query(Document dom, String query, Node context) throws XPathExpressionException { NodeList nodeList; XPath xpath = XPathFactory.newInstance().newXPath(); xpath.setNamespaceContext(new NamespaceContext() { @Override public String getNamespaceURI(String prefix) { String result = null; if (prefix.equals("samlp") || prefix.equals("samlp2")) { result = Constants.NS_SAMLP; } else if (prefix.equals("saml") || prefix.equals("saml2")) { result = Constants.NS_SAML; } else if (prefix.equals("ds")) { result = Constants.NS_DS; } else if (prefix.equals("xenc")) { result = Constants.NS_XENC; } else if (prefix.equals("md")) { result = Constants.NS_MD; } return result; } @Override public String getPrefix(String namespaceURI) { return null; } @SuppressWarnings("rawtypes") @Override public Iterator getPrefixes(String namespaceURI) { return null; } }); if (context == null) { nodeList = (NodeList) xpath.evaluate(query, dom, XPathConstants.NODESET); } else { nodeList = (NodeList) xpath.evaluate(query, context, XPathConstants.NODESET); } return nodeList; } /** * Extracts a node from the DOMDocument * * @param dom * The DOMDocument * @param query * Xpath Expression * * @return DOMNodeList The queried node * * @throws XPathExpressionException */ 
public static NodeList query(Document dom, String query) throws XPathExpressionException { return query(dom, query, null); } /** * This function attempts to validate an XML against the specified schema. * * @param xmlDocument * The XML document which should be validated * @param schemaUrl * The schema filename which should be used * * @return found errors after validation */ public static boolean validateXML(Document xmlDocument, URL schemaUrl) { try { if (xmlDocument == null) { throw new IllegalArgumentException("xmlDocument was null"); } Schema schema = SchemaFactory.loadFromUrl(schemaUrl); Validator validator = schema.newValidator(); if (JAXP_15_SUPPORTED) { // Prevent XXE attacks validator.setProperty(XMLConstants.ACCESS_EXTERNAL_DTD, ""); validator.setProperty(XMLConstants.ACCESS_EXTERNAL_SCHEMA, ""); } XMLErrorAccumulatorHandler errorAcumulator = new XMLErrorAccumulatorHandler(); validator.setErrorHandler(errorAcumulator); Source xmlSource = new DOMSource(xmlDocument); validator.validate(xmlSource); final boolean isValid = !errorAcumulator.hasError(); if (!isValid) { LOGGER.warn("Errors found when validating SAML response with schema: " + errorAcumulator.getErrorXML()); } return isValid; } catch (Exception e) { LOGGER.warn("Error executing validateXML: " + e.getMessage(), e); return false; } } /** * Converts an XML in string format in a Document object * * @param xmlStr * The XML string which should be converted * * @return the Document object * * @throws ParserConfigurationException * @throws SAXException * @throws IOException */ public static Document convertStringToDocument(String xmlStr) throws ParserConfigurationException, SAXException, IOException { DocumentBuilderFactory docfactory = DocumentBuilderFactory.newInstance(); docfactory.setNamespaceAware(true); // do not expand entity reference nodes docfactory.setExpandEntityReferences(false); docfactory.setAttribute("http://java.sun.com/xml/jaxp/properties/schemaLanguage", 
XMLConstants.W3C_XML_SCHEMA_NS_URI); // Add various options explicitly to prevent XXE attacks. // (adding try/catch around every setAttribute just in case a specific parser does not support it. try { // do not include external general entities docfactory.setAttribute("http://xml.org/sax/features/external-general-entities", Boolean.FALSE); } catch (Throwable e) {} try { // do not include external parameter entities or the external DTD subset docfactory.setAttribute("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE); } catch (Throwable e) {} try { docfactory.setAttribute("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE); } catch (Throwable e) {} try { docfactory.setAttribute("http://javax.xml.XMLConstants/feature/secure-processing", Boolean.TRUE); } catch (Throwable e) {} try { // ignore the external DTD completely docfactory.setAttribute("http://apache.org/xml/features/nonvalidating/load-external-dtd", Boolean.FALSE); } catch (Throwable e) {} try { // build the grammar but do not use the default attributes and attribute types information it contains docfactory.setAttribute("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", Boolean.FALSE); } catch (Throwable e) {} try { docfactory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true); } catch (Throwable e) {} DocumentBuilder builder = docfactory.newDocumentBuilder(); Document doc = builder.parse(new InputSource(new StringReader(xmlStr))); // Loop through the doc and tag every element with an ID attribute // as an XML ID node. 
XPath xpath = XPathFactory.newInstance().newXPath(); XPathExpression expr; try { expr = xpath.compile("//*[@ID]"); NodeList nodeList = (NodeList) expr.evaluate(doc, XPathConstants.NODESET); for (int i = 0; i < nodeList.getLength(); i++) { Element elem = (Element) nodeList.item(i); Attr attr = (Attr) elem.getAttributes().getNamedItem("ID"); elem.setIdAttributeNode(attr, true); } } catch (XPathExpressionException e) { return null; } return doc; } /** * Converts an XML in Document format in a String * * @param doc * The Document object * @param c14n * If c14n transformation should be applied * * @return the Document object */ public static String convertDocumentToString(Document doc, Boolean c14n) { org.apache.xml.security.Init.init(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); if (c14n) { XMLUtils.outputDOMc14nWithComments(doc, baos); } else { XMLUtils.outputDOM(doc, baos); } return Util.toStringUtf8(baos.toByteArray()); } /** * Converts an XML in Document format in a String without applying the c14n transformation * * @param doc * The Document object * * @return the Document object */ public static String convertDocumentToString(Document doc) { return convertDocumentToString(doc, false); } /** * Returns a certificate in String format (adding header and footer if required) * * @param cert * A x509 unformatted cert * @param heads * True if we want to include head and footer * * @return X509Certificate $x509 Formated cert */ public static String formatCert(String cert, Boolean heads) { String x509cert = StringUtils.EMPTY; if (cert != null) { x509cert = cert.replace("\\x0D", "").replace("\r", "").replace("\n", "").replace(" ", ""); if (!StringUtils.isEmpty(x509cert)) { x509cert = x509cert.replace("-----BEGINCERTIFICATE-----", "").replace("-----ENDCERTIFICATE-----", ""); if (heads) { x509cert = "-----BEGIN CERTIFICATE-----\n" + chunkString(x509cert, 64) + "-----END CERTIFICATE-----"; } } } return x509cert; } /** * Returns a private key (adding header and 
footer if required). * * @param key * A private key * @param heads * True if we want to include head and footer * * @return Formated private key */ public static String formatPrivateKey(String key, boolean heads) { String xKey = StringUtils.EMPTY; if (key != null) { xKey = key.replace("\\x0D", "").replace("\r", "").replace("\n", "").replace(" ", ""); if (!StringUtils.isEmpty(xKey)) { if (xKey.startsWith("-----BEGINPRIVATEKEY-----")) { xKey = xKey.replace("-----BEGINPRIVATEKEY-----", "").replace("-----ENDPRIVATEKEY-----", ""); if (heads) { xKey = "-----BEGIN PRIVATE KEY-----\n" + chunkString(xKey, 64) + "-----END PRIVATE KEY-----"; } } else { xKey = xKey.replace("-----BEGINRSAPRIVATEKEY-----", "").replace("-----ENDRSAPRIVATEKEY-----", ""); if (heads) { xKey = "-----BEGIN RSA PRIVATE KEY-----\n" + chunkString(xKey, 64) + "-----END RSA PRIVATE KEY-----"; } } } } return xKey; } /** * chunk a string * * @param str * The string to be chunked * @param chunkSize * The chunk size * * @return the chunked string */ private static String chunkString(String str, int chunkSize) { String newStr = StringUtils.EMPTY; int stringLength = str.length(); for (int i = 0; i < stringLength; i += chunkSize) { if (i + chunkSize > stringLength) { chunkSize = stringLength - i; } newStr += str.substring(i, chunkSize + i) + '\n'; } return newStr; } /** * Load X.509 certificate * * @param certString * certificate in string format * * @return Loaded Certificate. 
X509Certificate object * * @throws UnsupportedEncodingException * @throws CertificateException * */ public static X509Certificate loadCert(String certString) throws CertificateException, UnsupportedEncodingException { certString = formatCert(certString, true); X509Certificate cert; try { cert = (X509Certificate) CertificateFactory.getInstance("X.509").generateCertificate( new ByteArrayInputStream(certString.getBytes("utf-8"))); } catch (IllegalArgumentException e){ cert = null; } return cert; } /** * Load private key * * @param keyString * private key in string format * * @return Loaded private key. PrivateKey object * * @throws GeneralSecurityException * @throws IOException */ public static PrivateKey loadPrivateKey(String keyString) throws GeneralSecurityException, IOException { org.apache.xml.security.Init.init(); keyString = formatPrivateKey(keyString, false); keyString = chunkString(keyString, 64); KeyFactory kf = KeyFactory.getInstance("RSA"); PrivateKey privKey; try { byte[] encoded = Base64.decodeBase64(keyString); PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(encoded); privKey = (PrivateKey) kf.generatePrivate(keySpec); } catch(IllegalArgumentException e) { privKey = null; } return privKey; } /** * Calculates the fingerprint of a x509cert * * @param x509cert * x509 certificate * @param alg * Digest Algorithm * * @return the formated fingerprint */ public static String calculateX509Fingerprint(X509Certificate x509cert, String alg) { String fingerprint = StringUtils.EMPTY; try { byte[] dataBytes = x509cert.getEncoded(); if (alg == null || alg.isEmpty() || alg.equals("SHA-1")|| alg.equals("sha1")) { fingerprint = DigestUtils.sha1Hex(dataBytes); } else if (alg.equals("SHA-256") || alg .equals("sha256")) { fingerprint = DigestUtils.sha256Hex(dataBytes); } else if (alg.equals("SHA-384") || alg .equals("sha384")) { fingerprint = DigestUtils.sha384Hex(dataBytes); } else if (alg.equals("SHA-512") || alg.equals("sha512")) { fingerprint = 
DigestUtils.sha512Hex(dataBytes); } else { LOGGER.debug("Error executing calculateX509Fingerprint. alg " + alg + " not supported"); } } catch (Exception e) { LOGGER.debug("Error executing calculateX509Fingerprint: "+ e.getMessage(), e); } return fingerprint.toLowerCase(); } /** * Calculates the SHA-1 fingerprint of a x509cert * * @param x509cert * x509 certificate * * @return the SHA-1 formated fingerprint */ public static String calculateX509Fingerprint(X509Certificate x509cert) { return calculateX509Fingerprint(x509cert, "SHA-1"); } /** * Converts an X509Certificate in a well formated PEM string * * @param certificate * The public certificate * * @return the formated PEM string */ public static String convertToPem(X509Certificate certificate) { String pemCert = ""; try { Base64 encoder = new Base64(64); String cert_begin = "-----BEGIN CERTIFICATE-----\n"; String end_cert = "-----END CERTIFICATE-----"; byte[] derCert = certificate.getEncoded(); String pemCertPre = new String(encoder.encode(derCert)); pemCert = cert_begin + pemCertPre + end_cert; } catch (Exception e) { LOGGER.debug("Error converting certificate on PEM format: "+ e.getMessage(), e); } return pemCert; } /** * Loads a resource located at a relative path * * @param relativeResourcePath * Relative path of the resource * * @return the loaded resource in String format * * @throws IOException */ public static String getFileAsString(String relativeResourcePath) throws IOException { InputStream is = Util.class.getResourceAsStream("/" + relativeResourcePath); if (is == null) { throw new FileNotFoundException(relativeResourcePath); } try { ByteArrayOutputStream bytes = new ByteArrayOutputStream(); copyBytes(new BufferedInputStream(is), bytes); return bytes.toString("utf-8"); } finally { is.close(); } } private static void copyBytes(InputStream is, OutputStream bytes) throws IOException { int res = is.read(); while (res != -1) { bytes.write(res); res = is.read(); } } /** * Returns String Base64 decoded and 
inflated * * @param input * String input * * @return the base64 decoded and inflated string */ public static String base64decodedInflated(String input) { // Base64 decoder byte[] decoded = Base64.decodeBase64(input); // Inflater try { Inflater decompresser = new Inflater(true); decompresser.setInput(decoded); byte[] result = new byte[2048]; int resultLength = decompresser.inflate(result); decompresser.end(); String inflated = new String(result, 0, resultLength, "UTF-8"); return inflated; } catch (Exception e) { return new String(decoded); } } /** * Returns String Deflated and base64 encoded * * @param input * String input * * @return the deflated and base64 encoded string * @throws IOException */ public static String deflatedBase64encoded(String input) throws IOException { // Deflater ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); Deflater deflater = new Deflater(Deflater.DEFLATED, true); DeflaterOutputStream deflaterStream = new DeflaterOutputStream(bytesOut, deflater); deflaterStream.write(input.getBytes(Charset.forName("UTF-8"))); deflaterStream.finish(); // Base64 encoder return new String(Base64.encodeBase64(bytesOut.toByteArray())); } /** * Returns String base64 encoded * * @param input * Stream input * * @return the base64 encoded string */ public static String base64encoder(byte [] input) { return toStringUtf8(Base64.encodeBase64(input)); } /** * Returns String base64 encoded * * @param input * String input * * @return the base64 encoded string */ public static String base64encoder(String input) { return base64encoder(toBytesUtf8(input)); } /** * Returns String base64 decoded * * @param input * Stream input * * @return the base64 decoded bytes */ public static byte[] base64decoder(byte [] input) { return Base64.decodeBase64(input); } /** * Returns String base64 decoded * * @param input * String input * * @return the base64 decoded bytes */ public static byte[] base64decoder(String input) { return base64decoder(toBytesUtf8(input)); } /** * 
Returns String URL encoded * * @param input * String input * * @return the URL encoded string */ public static String urlEncoder(String input) { if (input != null) { try { return URLEncoder.encode(input, "UTF-8"); } catch (UnsupportedEncodingException e) { LOGGER.error("URL encoder error.", e); throw new IllegalArgumentException(); } } else { return null; } } /** * Returns String URL decoded * * @param input * URL encoded input * * @return the URL decoded string */ public static String urlDecoder(String input) { if (input != null) { try { return URLDecoder.decode(input, "UTF-8"); } catch (UnsupportedEncodingException e) { LOGGER.error("URL decoder error.", e); throw new IllegalArgumentException(); } } else { return null; } } /** * Generates a signature from a string * * @param text * The string we should sign * @param key * The private key to sign the string * @param signAlgorithm * Signature algorithm method * * @return the signature * * @throws NoSuchAlgorithmException * @throws InvalidKeyException * @throws SignatureException */ public static byte[] sign(String text, PrivateKey key, String signAlgorithm) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { org.apache.xml.security.Init.init(); if (signAlgorithm == null) { signAlgorithm = Constants.RSA_SHA1; } Signature instance = Signature.getInstance(signatureAlgConversion(signAlgorithm)); instance.initSign(key); instance.update(text.getBytes()); byte[] signature = instance.sign(); return signature; } /** * Converts Signature algorithm method name * * @param sign * signature algorithm method * * @return the converted signature name */ public static String signatureAlgConversion(String sign) { String convertedSignatureAlg = ""; if (sign == null) { convertedSignatureAlg = "SHA1withRSA"; } else if (sign.equals(Constants.DSA_SHA1)) { convertedSignatureAlg = "SHA1withDSA"; } else if (sign.equals(Constants.RSA_SHA256)) { convertedSignatureAlg = "SHA256withRSA"; } else if 
(sign.equals(Constants.RSA_SHA384)) { convertedSignatureAlg = "SHA384withRSA"; } else if (sign.equals(Constants.RSA_SHA512)) { convertedSignatureAlg = "SHA512withRSA"; } else { convertedSignatureAlg = "SHA1withRSA"; } return convertedSignatureAlg; } /** * Validate the signature pointed to by the xpath * * @param doc The document we should validate * @param cert The public certificate * @param fingerprint The fingerprint of the public certificate * @param alg The signature algorithm method * @param xpath the xpath of the ds:Signture node to validate * * @return True if the signature exists and is valid, false otherwise. */ public static boolean validateSign(final Document doc, final X509Certificate cert, final String fingerprint, final String alg, final String xpath) { try { final NodeList signatures = query(doc, xpath); return signatures.getLength() == 1 && validateSignNode(signatures.item(0), cert, fingerprint, alg); } catch (XPathExpressionException e) { LOGGER.warn("Failed to find signature nodes", e); } return false; } /** * Validate signature (Metadata). * * @param doc * The document we should validate * @param cert * The public certificate * @param fingerprint * The fingerprint of the public certificate * @param alg * The signature algorithm method * * @return True if the sign is valid, false otherwise. 
*/ public static Boolean validateMetadataSign(Document doc, X509Certificate cert, String fingerprint, String alg) { NodeList signNodesToValidate; try { signNodesToValidate = query(doc, "/md:EntitiesDescriptor/ds:Signature"); if (signNodesToValidate.getLength() == 0) { signNodesToValidate = query(doc, "/md:EntityDescriptor/ds:Signature"); if (signNodesToValidate.getLength() == 0) { signNodesToValidate = query(doc, "/md:EntityDescriptor/md:SPSSODescriptor/ds:Signature|/md:EntityDescriptor/IDPSSODescriptor/ds:Signature"); } } if (signNodesToValidate.getLength() > 0) { for (int i = 0; i < signNodesToValidate.getLength(); i++) { Node signNode = signNodesToValidate.item(i); if (!validateSignNode(signNode, cert, fingerprint, alg)) { return false; } } return true; } } catch (XPathExpressionException e) { LOGGER.warn("Failed to find signature nodes", e); } return false; } /** * Validate signature of the Node. * * @param signNode * The document we should validate * @param cert * The public certificate * @param fingerprint * The fingerprint of the public certificate * @param alg * The signature algorithm method * * @return True if the sign is valid, false otherwise. */ public static Boolean validateSignNode(Node signNode, X509Certificate cert, String fingerprint, String alg) { Boolean res = false; try { org.apache.xml.security.Init.init(); Element sigElement = (Element) signNode; XMLSignature signature = new XMLSignature(sigElement, "", true); if (cert != null) { res = signature.checkSignatureValue(cert); } else { KeyInfo keyInfo = signature.getKeyInfo(); if (keyInfo != null && keyInfo.containsX509Data()) { X509Certificate providedCert = keyInfo.getX509Certificate(); if (fingerprint.equals(calculateX509Fingerprint(providedCert, alg))) { res = signature.checkSignatureValue(providedCert); } } } } catch (Exception e) { LOGGER.warn("Error executing validateSignNode: " + e.getMessage(), e); } return res; } /** * Decrypt an encrypted element. 
* * @param encryptedDataElement * The encrypted element. * @param inputKey * The private key to decrypt. */ public static void decryptElement(Element encryptedDataElement, PrivateKey inputKey) { try { org.apache.xml.security.Init.init(); XMLCipher xmlCipher = XMLCipher.getInstance(); xmlCipher.init(XMLCipher.DECRYPT_MODE, null); /* Check if we have encryptedData with a KeyInfo that contains a RetrievalMethod to obtain the EncryptedKey. xmlCipher is not able to handle that so we move the EncryptedKey inside the KeyInfo element and replacing the RetrievalMethod. */ NodeList keyInfoInEncData = encryptedDataElement.getElementsByTagNameNS(Constants.NS_DS, "KeyInfo"); if (keyInfoInEncData.getLength() == 0) { throw new ValidationError("No KeyInfo inside EncryptedData element", ValidationError.KEYINFO_NOT_FOUND_IN_ENCRYPTED_DATA); } NodeList childs = keyInfoInEncData.item(0).getChildNodes(); for (int i=0; i < childs.getLength(); i++) { if (childs.item(i).getLocalName() != null && childs.item(i).getLocalName().equals("RetrievalMethod")) { Element retrievalMethodElem = (Element)childs.item(i); if (!retrievalMethodElem.getAttribute("Type").equals("http://www.w3.org/2001/04/xmlenc#EncryptedKey")) { throw new ValidationError("Unsupported Retrieval Method found", ValidationError.UNSUPPORTED_RETRIEVAL_METHOD); } String uri = retrievalMethodElem.getAttribute("URI").substring(1); NodeList encryptedKeyNodes = ((Element) encryptedDataElement.getParentNode()).getElementsByTagNameNS(Constants.NS_XENC, "EncryptedKey"); for (int j=0; j < encryptedKeyNodes.getLength(); j++) { if (((Element)encryptedKeyNodes.item(j)).getAttribute("Id").equals(uri)) { keyInfoInEncData.item(0).replaceChild(encryptedKeyNodes.item(j), childs.item(i)); } } } } xmlCipher.setKEK(inputKey); xmlCipher.doFinal(encryptedDataElement.getOwnerDocument(), encryptedDataElement, false); } catch (Exception e) { LOGGER.warn("Error executing decryption: " + e.getMessage(), e); } } /** * Clone a Document object. 
* * @param source * The Document object to be cloned. * * @return the clone of the Document object * * @throws ParserConfigurationException */ public static Document copyDocument(Document source) throws ParserConfigurationException { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); dbf.setNamespaceAware(true); DocumentBuilder db = dbf.newDocumentBuilder(); Node originalRoot = source.getDocumentElement(); Document copiedDocument = db.newDocument(); Node copiedRoot = copiedDocument.importNode(originalRoot, true); copiedDocument.appendChild(copiedRoot); return copiedDocument; } /** * Signs the Document using the specified signature algorithm with the private key and the public certificate. * * @param document * The document to be signed * @param key * The private key * @param certificate * The public certificate * @param signAlgorithm * Signature Algorithm * * @return the signed document in string format * * @throws XMLSecurityException * @throws XPathExpressionException */ public static String addSign(Document document, PrivateKey key, X509Certificate certificate, String signAlgorithm) throws XMLSecurityException, XPathExpressionException { org.apache.xml.security.Init.init(); // Check arguments. 
if (document == null) { throw new IllegalArgumentException("Provided document was null"); } if (document.getDocumentElement() == null) { throw new IllegalArgumentException("The Xml Document has no root element."); } if (key == null) { throw new IllegalArgumentException("Provided key was null"); } if (certificate == null) { throw new IllegalArgumentException("Provided certificate was null"); } if (signAlgorithm == null || signAlgorithm.isEmpty()) { signAlgorithm = Constants.RSA_SHA1; } // document.normalizeDocument(); String c14nMethod = Constants.C14N_WC; // Signature object XMLSignature sig = new XMLSignature(document, null, signAlgorithm, c14nMethod); // Including the signature into the document before sign, because // this is an envelop signature Element root = document.getDocumentElement(); document.setXmlStandalone(false); // If Issuer, locate Signature after Issuer, Otherwise as first child. NodeList issuerNodes = Util.query(document, "//saml:Issuer", null); if (issuerNodes.getLength() > 0) { Node issuer = issuerNodes.item(0); root.insertBefore(sig.getElement(), issuer.getNextSibling()); } else { root.insertBefore(sig.getElement(), root.getFirstChild()); } String id = root.getAttribute("ID"); String reference = id; if (!id.isEmpty()) { root.setIdAttributeNS(null, "ID", true); reference = "#" + id; } // Create the transform for the document Transforms transforms = new Transforms(document); transforms.addTransform(Constants.ENVSIG); //transforms.addTransform(Transforms.TRANSFORM_C14N_OMIT_COMMENTS); transforms.addTransform(c14nMethod); sig.addDocument(reference, transforms, Constants.SHA1); // Add the certification info sig.addKeyInfo(certificate); // Sign the document sig.sign(key); return convertDocumentToString(document, true); } /** * Signs a Node using the specified signature algorithm with the private key and the public certificate. 
* * @param node * The Node to be signed * @param key * The private key * @param certificate * The public certificate * @param signAlgorithm * Signature Algorithm * * @return the signed document in string format * * @throws ParserConfigurationException * @throws XMLSecurityException * @throws XPathExpressionException */ public static String addSign(Node node, PrivateKey key, X509Certificate certificate, String signAlgorithm) throws ParserConfigurationException, XPathExpressionException, XMLSecurityException { // Check arguments. if (node == null) { throw new IllegalArgumentException("Provided node was null"); } DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); dbf.setNamespaceAware(true); Document doc = dbf.newDocumentBuilder().newDocument(); Node newNode = doc.importNode(node, true); doc.appendChild(newNode); return addSign(doc, key, certificate, signAlgorithm); } /** * Validates signed binary data (Used to validate GET Signature). * * @param signedQuery * The element we should validate * @param signature * The signature that will be validate * @param cert * The public certificate * @param signAlg * Signature Algorithm * * @return the signed document in string format * * @throws NoSuchAlgorithmException * @throws NoSuchProviderException * @throws InvalidKeyException * @throws SignatureException */ public static Boolean validateBinarySignature(String signedQuery, byte[] signature, X509Certificate cert, String signAlg) throws NoSuchAlgorithmException, NoSuchProviderException, InvalidKeyException, SignatureException { Boolean valid = false; try { org.apache.xml.security.Init.init(); String convertedSigAlg = signatureAlgConversion(signAlg); Signature sig = Signature.getInstance(convertedSigAlg); //, provider); sig.initVerify(cert.getPublicKey()); sig.update(signedQuery.getBytes()); valid = sig.verify(signature); } catch (Exception e) { LOGGER.warn("Error executing validateSign: " + e.getMessage(), e); } return valid; } /** * Generates a nameID. 
* * @param value * The value * @param spnq * SP Name Qualifier * @param format * SP Format * @param cert * IdP Public certificate to encrypt the nameID * * @return Xml contained in the document. */ public static String generateNameId(String value, String spnq, String format, X509Certificate cert) { String res = null; try { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); dbf.setNamespaceAware(true); Document doc = dbf.newDocumentBuilder().newDocument(); Element nameId = doc.createElement("saml:NameID"); if (spnq != null && !spnq.isEmpty()) { nameId.setAttribute("SPNameQualifier", spnq); } if (format != null && !format.isEmpty()) { nameId.setAttribute("Format", format); } nameId.appendChild(doc.createTextNode(value)); doc.appendChild(nameId); if (cert != null) { // We generate a symmetric key Key symmetricKey = generateSymmetricKey(); // cipher for encrypt the data XMLCipher xmlCipher = XMLCipher.getInstance(Constants.AES128_CBC); xmlCipher.init(XMLCipher.ENCRYPT_MODE, symmetricKey); // cipher for encrypt the symmetric key XMLCipher keyCipher = XMLCipher.getInstance(Constants.RSA_1_5); keyCipher.init(XMLCipher.WRAP_MODE, cert.getPublicKey()); // encrypt the symmetric key EncryptedKey encryptedKey = keyCipher.encryptKey(doc, symmetricKey); // Add keyinfo inside the encrypted data EncryptedData encryptedData = xmlCipher.getEncryptedData(); KeyInfo keyInfo = new KeyInfo(doc); keyInfo.add(encryptedKey); encryptedData.setKeyInfo(keyInfo); // Encrypt the actual data xmlCipher.doFinal(doc, nameId, false); // Building the result res = "<saml:EncryptedID>" + convertDocumentToString(doc) + "</saml:EncryptedID>"; } else { res = convertDocumentToString(doc); } } catch (Exception e) { LOGGER.error("Error executing generateNameId: " + e.getMessage(), e); } return res; } /** * Generates a nameID. * * @param value * The value * @param spnq * SP Name Qualifier * @param format * SP Format * * @return Xml contained in the document. 
*/
    public static String generateNameId(String value, String spnq, String format) {
        // No certificate means the four-argument overload leaves the NameID unencrypted.
        final X509Certificate noEncryptionCert = null;
        return generateNameId(value, spnq, format, noEncryptionCert);
    }

    /**
     * Creates a new 128-bit AES key for encrypting a NameID.
     *
     * @return the symmetric key
     *
     * @throws Exception
     *             if the AES key generator cannot be obtained or initialised
     */
    private static SecretKey generateSymmetricKey() throws Exception {
        final KeyGenerator aesGenerator = KeyGenerator.getInstance("AES");
        aesGenerator.init(128);
        final SecretKey generated = aesGenerator.generateKey();
        return generated;
    }

    /**
     * Builds a new random identifier (used, for example, as the ID of assertions).
     *
     * @return {@code UNIQUE_ID_PREFIX} followed by a random UUID
     */
    public static String generateUniqueID() {
        final String randomPart = UUID.randomUUID().toString();
        return UNIQUE_ID_PREFIX + randomPart;
    }

    /**
     * Applies an ISO 8601 duration to the current time.
     *
     * @param duration
     *            The duration, as a string; a leading {@code -} subtracts it.
     *
     * @return the resulting unix timestamp, in seconds.
     *
     * @throws IllegalArgumentException
     *             if the duration cannot be parsed
     */
    public static long parseDuration(String duration) throws IllegalArgumentException {
        final long nowInSeconds = System.currentTimeMillis() / 1000;
        return parseDuration(duration, nowInSeconds);
    }

    /**
     * Applies an ISO 8601 duration to a given unix timestamp.
     *
     * @param durationString
     *            The duration, as a string; a leading {@code -} subtracts it.
     * @param timestamp
     *            The unix timestamp, in seconds, the duration is applied to.
     *
     * @return the new timestamp, in seconds, after the duration is applied.
* * @throws IllegalArgumentException */ public static long parseDuration(String durationString, long timestamp) throws IllegalArgumentException { boolean haveMinus = false; if (durationString.startsWith("-")) { durationString = durationString.substring(1); haveMinus = true; } PeriodFormatter periodFormatter = ISOPeriodFormat.standard().withLocale(new Locale("UTC")); Period period = periodFormatter.parsePeriod(durationString); DateTime dt = new DateTime(timestamp * 1000, DateTimeZone.UTC); DateTime result = null; if (haveMinus) { result = dt.minus(period); } else { result = dt.plus(period); } return result.getMillis() / 1000; } /** * @return the unix timestamp that matches the current time. */ public static Long getCurrentTimeStamp() { DateTime currentDate = new DateTime(DateTimeZone.UTC); return currentDate.getMillis() / 1000; } /** * Compare 2 dates and return the the earliest * * @param cacheDuration * The duration, as a string. * @param validUntil * The valid until date, as a string * * @return the expiration time (timestamp format). */ public static long getExpireTime(String cacheDuration, String validUntil) { long expireTime = 0; try { if (cacheDuration != null && !StringUtils.isEmpty(cacheDuration)) { expireTime = parseDuration(cacheDuration); } if (validUntil != null && !StringUtils.isEmpty(validUntil)) { DateTime dt = Util.parseDateTime(validUntil); long validUntilTimeInt = dt.getMillis() / 1000; if (expireTime == 0 || expireTime > validUntilTimeInt) { expireTime = validUntilTimeInt; } } } catch (Exception e) { LOGGER.error("Error executing getExpireTime: " + e.getMessage(), e); } return expireTime; } /** * Compare 2 dates and return the the earliest * * @param cacheDuration * The duration, as a string. * @param validUntil * The valid until date, as a timestamp * * @return the expiration time (timestamp format). 
*/
    public static long getExpireTime(String cacheDuration, long validUntil) {
        long expireTime = 0;
        if (StringUtils.isNotEmpty(cacheDuration)) {
            try {
                expireTime = parseDuration(cacheDuration);
            } catch (Exception e) {
                // An unparseable duration is ignored so that validUntil still applies.
                LOGGER.error("Error executing getExpireTime: " + e.getMessage(), e);
            }
        }
        // 0 means "no usable cache duration", so validUntil is used in that case.
        if (expireTime == 0 || expireTime > validUntil) {
            expireTime = validUntil;
        }
        return expireTime;
    }

    /**
     * Creates a string from a time in milliseconds, with format yyyy-MM-ddTHH:mm:ssZ.
     *
     * @param timeInMillis
     *            The time in milliseconds
     *
     * @return string with format yyyy-MM-ddTHH:mm:ssZ
     */
    public static String formatDateTime(long timeInMillis) {
        return DATE_TIME_FORMAT.print(timeInMillis);
    }

    /**
     * Creates a string from a time in milliseconds, optionally including milliseconds in the output.
     *
     * @param time
     *            The time, in milliseconds
     * @param millis
     *            Whether to format with the millisecond-precision pattern
     *            ({@code DATE_TIME_FORMAT_MILLS}) instead of the second-precision one
     *
     * @return the formatted date-time string
     */
    public static String formatDateTime(long time, boolean millis) {
        return millis ? DATE_TIME_FORMAT_MILLS.print(time) : formatDateTime(time);
    }

    /**
     * Parses a date-time string with format yyyy-MM-ddTHH:mm:ssZ or yyyy-MM-ddTHH:mm:ss.SSSZ.
     *
     * @param dateTime
     *            string with format yyyy-MM-ddTHH:mm:ssZ // yyyy-MM-ddTHH:mm:ss.SSSZ
     *
     * @return the parsed DateTime
     */
    public static DateTime parseDateTime(String dateTime) {
        try {
            return DATE_TIME_FORMAT.parseDateTime(dateTime);
        } catch (Exception e) {
            // Not in the second-precision format; retry with the millisecond format and let
            // any failure from that attempt propagate to the caller.
            return DATE_TIME_FORMAT_MILLS.parseDateTime(dateTime);
        }
    }

    /**
     * Decodes bytes as UTF-8; malformed sequences are replaced rather than rejected.
     */
    private static String toStringUtf8(byte[] bytes) {
        return new String(bytes, Charset.forName("UTF-8"));
    }

    /**
     * Encodes a string as UTF-8 bytes.
     */
    private static byte[] toBytesUtf8(String str) {
        return str.getBytes(Charset.forName("UTF-8"));
    }
}